{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# K-Means Clustering Tutorial\n",
    "\n",
    "This guide shows how to use one of Tribuo's clustering models to find clusters in a toy dataset drawn from a mixture of Gaussians. We'll look at Tribuo's K-Means implementation and also discuss how evaluation works for clustering tasks.\n",
    "\n",
    "## Setup\n",
    "\n",
    "We'll load in some jars and import a few packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%jars ./tribuo-clustering-kmeans-4.3.0-jar-with-dependencies.jar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.tribuo.*;\n",
    "import org.tribuo.util.Util;\n",
    "import org.tribuo.clustering.*;\n",
    "import org.tribuo.clustering.evaluation.*;\n",
    "import org.tribuo.clustering.example.GaussianClusterDataSource;\n",
    "import org.tribuo.clustering.kmeans.*;\n",
    "import org.tribuo.clustering.kmeans.KMeansTrainer.Initialisation;\n",
    "import org.tribuo.math.distance.DistanceType;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "var eval = new ClusteringEvaluator();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "Tribuo's clustering package comes with a simple data source that emits data sampled from a mixture of five 2-dimensional Gaussians (the dimensionality of the Gaussians can be anywhere from 1 to 4, and the means and variances can also be set arbitrarily). This source sets the ground truth cluster IDs, so it can be used to measure clustering performance for demos like this. You can also use any of the standard data loaders to pull in clustering data.\n",
    "\n",
    "As clustering conforms to the standard `Trainer` and `Model` interfaces used for the rest of Tribuo, training a clustering algorithm doesn't directly expose the cluster assignments. To recover the assignments we need to call `model.predict(trainData)`.\n",
    "\n",
    "We're going to sample two datasets (using different seeds): one for fitting the cluster centroids, and one for measuring clustering performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "var data = new MutableDataset<>(new GaussianClusterDataSource(500, 1L));\n",
    "var test = new MutableDataset<>(new GaussianClusterDataSource(500, 2L));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The defaults for the data source are:\n",
    "1. `N([ 0.0,0.0], [[1.0,0.0],[0.0,1.0]])`\n",
    "2. `N([ 5.0,5.0], [[1.0,0.0],[0.0,1.0]])`\n",
    "3. `N([ 2.5,2.5], [[1.0,0.5],[0.5,1.0]])`\n",
    "4. `N([10.0,0.0], [[0.1,0.0],[0.0,0.1]])`\n",
    "5. `N([-1.0,0.0], [[1.0,0.0],[0.0,0.1]])`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Training\n",
    "We'll first fit a K-Means model using 5 centroids, a maximum of 10 iterations, the Euclidean distance, and a single computation thread."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training with 5 clusters took (00:00:00:071)\n"
     ]
    }
   ],
   "source": [
    "var l2Dist = DistanceType.L2.getDistance();\n",
    "var trainer = new KMeansTrainer(5, /* centroids */\n",
    "                                10, /* iterations */\n",
    "                                l2Dist, /* distance function */\n",
    "                                1, /* number of compute threads */\n",
    "                                1 /* RNG seed */\n",
    "                               );\n",
    "var startTime = System.currentTimeMillis();\n",
    "var model = trainer.train(data);\n",
    "var endTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training with 5 clusters took \" + Util.formatDuration(startTime,endTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can inspect the centroids by querying the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(A, -1.729407), (B, -0.019281)]\n",
      "[(A, 2.740410), (B, 2.873769)]\n",
      "[(A, 0.051021), (B, 0.075766)]\n",
      "[(A, 5.174978), (B, 5.088150)]\n",
      "[(A, 9.938804), (B, -0.020702)]\n"
     ]
    }
   ],
   "source": [
    "var centroids = model.getCentroids();\n",
    "for (var centroid : centroids) {\n",
    "    System.out.println(centroid);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These centroids line up pretty well with the means of the Gaussians. The predicted ones line up with the true ones as follows:\n",
    "\n",
    "|Predicted|True|\n",
    "|---|---|\n",
    "|1|5|\n",
    "|2|3|\n",
    "|3|1|\n",
    "|4|2|\n",
    "|5|4|\n",
    "\n",
    "The first one is a bit far out, as its \"A\" feature should be -1.0 not -1.7, and there is a little wobble in the rest. Still, it's pretty good considering K-Means assumes spherical Gaussians and our data generator has a covariance matrix per Gaussian."
   ]
  },
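  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the spherical-Gaussian remark concrete, here is the objective that K-Means with the Euclidean distance tries to minimise, together with the two alternating updates. This is the standard textbook formulation rather than anything taken from Tribuo's source, and the symbols $x_i$ (data points), $\\mu_k$ (centroids), $z_i$ (cluster assignments) and $C_k$ (the points in cluster $k$) are notation introduced here.\n",
    "\n",
    "$$\n",
    "J(\\mu, z) = \\sum_{i=1}^{N} \\lVert x_i - \\mu_{z_i} \\rVert_2^2, \\qquad z_i = \\arg\\min_{k} \\lVert x_i - \\mu_k \\rVert_2^2, \\qquad \\mu_k = \\frac{1}{\\lvert C_k \\rvert} \\sum_{i \\in C_k} x_i, \\quad C_k = \\lbrace i : z_i = k \\rbrace\n",
    "$$\n",
    "\n",
    "Every cluster is scored with the same unweighted Euclidean distance, so the objective has no way to represent the correlated or unequal-variance components (3, 4 and 5) of the data generator."
   ]
  },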
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## K-Means++\n",
    "Tribuo also includes the K-Means++ initialisation algorithm, which we can run on our toy problem as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training with 5 clusters took (00:00:00:038)\n"
     ]
    }
   ],
   "source": [
    "var plusplusTrainer = new KMeansTrainer(5,10,l2Dist,Initialisation.PLUSPLUS,1,1);\n",
    "var startTime = System.currentTimeMillis();\n",
    "var plusplusModel = plusplusTrainer.train(data);\n",
    "var endTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training with 5 clusters took \" + Util.formatDuration(startTime,endTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training time isn't much different in this case, but the K-Means++ initialisation does take longer than the default on larger datasets. However, the resulting clusters are usually better.\n",
    "\n",
    "We can check the centroids from this model using the same method as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(A, -1.567863), (B, -0.029534)]\n",
      "[(A, 9.938804), (B, -0.020702)]\n",
      "[(A, 3.876203), (B, 3.930657)]\n",
      "[(A, 0.399868), (B, 0.330537)]\n",
      "[(A, 5.520480), (B, 5.390406)]\n"
     ]
    }
   ],
   "source": [
    "var ppCentroids = plusplusModel.getCentroids();\n",
    "for (var centroid : ppCentroids) {\n",
    "    System.out.println(centroid);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see in this case that the K-Means++ initialisation has warped the centroids slightly, so the fit isn't quite as nice as the default initialisation, but that's why we have evaluation data and measure model fit. K-Means++ usually improves the fit of a K-Means clustering, but it might be too complicated for this simple toy dataset."
   ]
  },
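  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers curious about what a K-Means++ initialisation does conceptually, here is a minimal plain-Java sketch of the standard seeding procedure, which samples each new centroid with probability proportional to its squared distance from the nearest centroid chosen so far. It is not Tribuo's implementation: the class `KMeansPlusPlusSketch` and its method names are hypothetical, and it works on raw `double[][]` points rather than Tribuo datasets.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "import java.util.Random;\n",
    "\n",
    "final class KMeansPlusPlusSketch {\n",
    "    // Squared Euclidean distance between two points of equal dimension.\n",
    "    static double squaredDistance(double[] a, double[] b) {\n",
    "        double sum = 0.0;\n",
    "        for (int i = 0; i < a.length; i++) {\n",
    "            double diff = a[i] - b[i];\n",
    "            sum += diff * diff;\n",
    "        }\n",
    "        return sum;\n",
    "    }\n",
    "\n",
    "    // Picks k initial centroids: the first uniformly at random, each later one\n",
    "    // with probability proportional to its squared distance from the nearest\n",
    "    // centroid chosen so far.\n",
    "    static double[][] chooseCentroids(double[][] points, int k, long seed) {\n",
    "        Random rng = new Random(seed);\n",
    "        double[][] centroids = new double[k][];\n",
    "        centroids[0] = points[rng.nextInt(points.length)].clone();\n",
    "        double[] minDist = new double[points.length];\n",
    "        Arrays.fill(minDist, Double.POSITIVE_INFINITY);\n",
    "        for (int c = 1; c < k; c++) {\n",
    "            // Refresh each point's distance to its nearest chosen centroid.\n",
    "            double total = 0.0;\n",
    "            for (int i = 0; i < points.length; i++) {\n",
    "                minDist[i] = Math.min(minDist[i], squaredDistance(points[i], centroids[c - 1]));\n",
    "                total += minDist[i];\n",
    "            }\n",
    "            // Sample an index from the distribution defined by minDist.\n",
    "            double threshold = rng.nextDouble() * total;\n",
    "            int chosen = points.length - 1;\n",
    "            double cumulative = 0.0;\n",
    "            for (int i = 0; i < points.length; i++) {\n",
    "                cumulative += minDist[i];\n",
    "                if (cumulative > threshold) {\n",
    "                    chosen = i;\n",
    "                    break;\n",
    "                }\n",
    "            }\n",
    "            centroids[c] = points[chosen].clone();\n",
    "        }\n",
    "        return centroids;\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Calling `KMeansPlusPlusSketch.chooseCentroids(points, 5, 1L)` returns five starting centroids that tend to be spread across the data. In this sketch the spread costs one extra pass over the points per centroid, which is consistent with the seeding taking longer on larger datasets."
   ]
  },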
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model evaluation\n",
    "Tribuo uses the normalized mutual information to measure the agreement between two clusterings. This measure is invariant to relabelling, so swapping the ID numbers of the centroids, which doesn't change the overall clustering, doesn't change the score either. We're going to compare against the ground truth cluster labels from the data generator.\n",
    "\n",
    "First for the training data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Clustering Evaluation\n",
       "Normalized MI = 0.8128096132028937\n",
       "Adjusted MI = 0.8113314999600724"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var trainEvaluation = eval.evaluate(model,data);\n",
    "trainEvaluation.toString();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then for the unseen test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Clustering Evaluation\n",
       "Normalized MI = 0.8154291916732408\n",
       "Adjusted MI = 0.8139169341974347"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var testEvaluation = eval.evaluate(model,test);\n",
    "testEvaluation.toString();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that, as expected, the clustering agrees well with the ground truth labels. K-Means (of the kind implemented in Tribuo) is similar to a Gaussian mixture using spherical Gaussians, whereas our data generator uses Gaussians with unequal variances and, in one case, correlated features, so it won't be perfect.\n",
    "\n",
    "We can also check the K-Means++ model in the same way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Clustering Evaluation\n",
       "Normalized MI = 0.7881995472105396\n",
       "Adjusted MI = 0.7864797287891137"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var testPlusPlusEvaluation = eval.evaluate(plusplusModel,test);\n",
    "testPlusPlusEvaluation.toString();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected given the slightly poorer centroids this initialisation produced, the fit is not quite as good. However, we emphasise that K-Means++ usually improves the quality of the clustering, and so it's worth testing out if you're clustering data with Tribuo."
   ]
  },
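  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To show why a mutual information score doesn't care how the clusters are numbered, here is a small plain-Java sketch that computes it from two arrays of cluster IDs. It is an illustration only, not Tribuo's code: the class `NormalizedMISketch` is hypothetical, and normalising by the arithmetic mean of the two entropies is an assumption made here (other normalisers exist, so the numbers need not match the evaluator's output exactly).\n",
    "\n",
    "```java\n",
    "final class NormalizedMISketch {\n",
    "    // Shannon entropy (in nats) of a count vector whose entries sum to n.\n",
    "    static double entropy(int[] counts, int n) {\n",
    "        double h = 0.0;\n",
    "        for (int count : counts) {\n",
    "            if (count > 0) {\n",
    "                double p = (double) count / n;\n",
    "                h -= p * Math.log(p);\n",
    "            }\n",
    "        }\n",
    "        return h;\n",
    "    }\n",
    "\n",
    "    // Mutual information between two labelings, divided by the mean of their entropies.\n",
    "    // IDs must lie in [0, numPredicted) and [0, numTrue) respectively.\n",
    "    // Assumption: arithmetic-mean normalisation.\n",
    "    static double normalizedMI(int[] predicted, int[] truth, int numPredicted, int numTrue) {\n",
    "        int n = predicted.length;\n",
    "        int[][] joint = new int[numPredicted][numTrue];\n",
    "        int[] predictedCounts = new int[numPredicted];\n",
    "        int[] trueCounts = new int[numTrue];\n",
    "        for (int i = 0; i < n; i++) {\n",
    "            joint[predicted[i]][truth[i]]++;\n",
    "            predictedCounts[predicted[i]]++;\n",
    "            trueCounts[truth[i]]++;\n",
    "        }\n",
    "        double mi = 0.0;\n",
    "        for (int p = 0; p < numPredicted; p++) {\n",
    "            for (int t = 0; t < numTrue; t++) {\n",
    "                if (joint[p][t] > 0) {\n",
    "                    double pJoint = (double) joint[p][t] / n;\n",
    "                    double pPred = (double) predictedCounts[p] / n;\n",
    "                    double pTrue = (double) trueCounts[t] / n;\n",
    "                    mi += pJoint * Math.log(pJoint / (pPred * pTrue));\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "        double meanEntropy = 0.5 * (entropy(predictedCounts, n) + entropy(trueCounts, n));\n",
    "        // Both labelings put everything in one cluster, so they agree trivially.\n",
    "        return meanEntropy == 0.0 ? 1.0 : mi / meanEntropy;\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Only the joint counts enter the calculation, so a clustering with its IDs swapped scores 1 against the original, while a labeling that is statistically independent of the ground truth scores 0."
   ]
  },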
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multithreading\n",
    "Tribuo's K-Means supports multi-threading of both the expectation and maximisation steps in the algorithm (i.e., the assignment of points to centroids, and the finding of the new centroids). We'll run the same experiment as before, both with 5 centroids and with 20 centroids, using 4 threads, though this time we'll use 2000 points for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training with 5 clusters on 4 threads took (00:00:00:053)\n"
     ]
    }
   ],
   "source": [
    "var mtData = new MutableDataset<>(new GaussianClusterDataSource(2000, 1L));\n",
    "var mtTrainer = new KMeansTrainer(5,10,l2Dist,4,1);\n",
    "var mtStartTime = System.currentTimeMillis();\n",
    "var mtModel = mtTrainer.train(mtData);\n",
    "var mtEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training with 5 clusters on 4 threads took \" + Util.formatDuration(mtStartTime,mtEndTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now with 20 centroids:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training with 20 clusters on 4 threads took (00:00:00:034)\n"
     ]
    }
   ],
   "source": [
    "var overTrainer = new KMeansTrainer(20,10,l2Dist,4,1);\n",
    "var overStartTime = System.currentTimeMillis();\n",
    "var overModel = overTrainer.train(mtData);\n",
    "var overEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training with 20 clusters on 4 threads took \" + Util.formatDuration(overStartTime,overEndTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can evaluate the two models as before, using our `ClusteringEvaluator`. First with 5 centroids:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Clustering Evaluation\n",
       "Normalized MI = 0.8104463467727057\n",
       "Adjusted MI = 0.8088941747417295"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var mtTestEvaluation = eval.evaluate(mtModel,test);\n",
    "mtTestEvaluation.toString();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then with 20:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Clustering Evaluation\n",
       "Normalized MI = 0.8647317143685641\n",
       "Adjusted MI = 0.8603214693630152"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var overTestEvaluation = eval.evaluate(overModel,test);\n",
    "overTestEvaluation.toString();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the multi-threaded versions run in about the same time as the single-threaded trainer, despite having 4 times the training data. The 20 centroid model fits the test data more tightly, though it is overparameterised, as the data was generated from only 5 Gaussians. This is common in clustering tasks, where it's hard to balance model fit against complexity. We'll look at adding more performance metrics in future releases so users can diagnose such issues."
   ]
  },
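  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The assignment of points to centroids parallelises naturally because each point is handled independently of the others. The plain-Java sketch below illustrates that idea with a parallel stream. It is only an illustration: the class `ParallelAssignmentSketch` is hypothetical, it relies on Java's common fork-join pool instead of a configurable thread count, and it makes no claim about how Tribuo schedules its threads.\n",
    "\n",
    "```java\n",
    "import java.util.stream.IntStream;\n",
    "\n",
    "final class ParallelAssignmentSketch {\n",
    "    // Index of the centroid closest to the point under squared Euclidean distance.\n",
    "    static int nearest(double[] point, double[][] centroids) {\n",
    "        int best = 0;\n",
    "        double bestDistance = Double.POSITIVE_INFINITY;\n",
    "        for (int c = 0; c < centroids.length; c++) {\n",
    "            double distance = 0.0;\n",
    "            for (int d = 0; d < point.length; d++) {\n",
    "                double diff = point[d] - centroids[c][d];\n",
    "                distance += diff * diff;\n",
    "            }\n",
    "            if (distance < bestDistance) {\n",
    "                bestDistance = distance;\n",
    "                best = c;\n",
    "            }\n",
    "        }\n",
    "        return best;\n",
    "    }\n",
    "\n",
    "    // Assigns every point to its nearest centroid, splitting the work across threads.\n",
    "    // The result keeps the original point order.\n",
    "    static int[] assign(double[][] points, double[][] centroids) {\n",
    "        return IntStream.range(0, points.length)\n",
    "                        .parallel()\n",
    "                        .map(i -> nearest(points[i], centroids))\n",
    "                        .toArray();\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "With only a few thousand points the per-point work is tiny, so thread overhead can easily outweigh any speed-up, which fits the similar timings seen above."
   ]
  },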
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "We looked at clustering using Tribuo's K-Means implementation, experimented with different initialisations, and compared both the single-threaded and multi-threaded versions. Then we looked at the performance metrics available when there are ground truth clusterings.\n",
    "\n",
    "We added HDBSCAN in Tribuo v4.2, and we plan to further expand Tribuo's clustering functionality to incorporate other algorithms in the future. If you want to help, or have specific algorithmic requirements, file an issue on our [GitHub page](https://github.com/oracle/tribuo)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "12+33"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
